#!/usr/bin/env python3
import os
import sys
from statistics import variance

import numpy as np
import pandas as pd
from scipy import stats

# Put the sibling "helpers" directory at the front of sys.path so the
# import that follows can be resolved from there (presumably that is where
# pure_http_client lives -- verify against the repository layout).
CURDIR = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, os.path.join(CURDIR, "helpers"))

from pure_http_client import ClickHouseClient


# one-way ANOVA (F-test) over an arbitrary number of sample groups, computed with SciPy
def scipy_anova(rvs):
    """Run SciPy's one-way ANOVA over the given sample groups.

    Args:
        rvs: Iterable of samples, one entry per group; each entry is passed
            to ``scipy.stats.f_oneway`` as a separate positional argument.

    Returns:
        The result object of ``scipy.stats.f_oneway``, which unpacks into
        ``(f_statistic, p_value)``.
    """
    groups = tuple(rvs)
    return stats.f_oneway(*groups)


def test_and_check(rvs, n_groups, f_stat, p_value, precision=1e-2):
    """Check ClickHouse's ``anova`` aggregate against reference values.

    Recreates the ``anova`` Memory table, inserts the first ``n_groups``
    samples from ``rvs`` (one INSERT per group, each value tagged with its
    group index), runs ``anova(left, right)`` and asserts that the returned
    F-statistic and p-value are within ``precision`` of the references.

    Args:
        rvs: Indexable collection of samples; ``rvs[g]`` holds the values
            inserted for group ``g``.
        n_groups: Number of groups from ``rvs`` to insert.
        f_stat: Reference F-statistic.
        p_value: Reference p-value.
        precision: Maximum allowed absolute difference for both values.

    Raises:
        AssertionError: If either value differs from its reference by
            ``precision`` or more.
    """
    client = ClickHouseClient()
    client.query("DROP TABLE IF EXISTS anova;")
    try:
        client.query(
            "CREATE TABLE anova (left Float64, right UInt64) ENGINE = Memory;"
        )
        for group in range(n_groups):
            values = ", ".join(f"({value},{group})" for value in rvs[group])
            client.query(f"""INSERT INTO anova VALUES {values};""")

        # The query rounds both tuple elements to 16 decimal places and names
        # the columns so they can be looked up by name in the result.
        real = client.query_return_df(
            """SELECT roundBankers(a.1, 16) as f_stat, roundBankers(a.2, 16) as p_value FROM (SELECT anova(left, right) as a FROM anova) FORMAT TabSeparatedWithNames;"""
        )

        real_f_stat = real["f_stat"][0]
        real_p_value = real["p_value"][0]
        assert (
            abs(real_f_stat - np.float64(f_stat)) < precision
        ), f"clickhouse_f_stat {real_f_stat}, py_f_stat {f_stat}"
        assert (
            abs(real_p_value - np.float64(p_value)) < precision
        ), f"clickhouse_p_value {real_p_value}, py_p_value {p_value}"
    finally:
        # Always drop the table, even when a query or an assertion fails,
        # so a failed run does not leave it behind on the server.
        client.query("DROP TABLE IF EXISTS anova;")


def test_anova():
    """Run the ANOVA comparison on four randomly generated scenarios.

    Each scenario draws ``n_groups`` normally distributed samples (rounded
    to two decimals), computes the reference statistics with
    ``scipy_anova`` and hands everything to ``test_and_check``. The
    distribution parameters may shift from one group to the next by a fixed
    per-scenario step.
    """
    # (n_groups, loc, scale, size, loc_step, scale_step, size_step)
    scenarios = (
        (3, 0, 5, 500, 5, 0, 0),  # shifted means, equal variance and size
        (6, 0, 5, 500, 0, 0, 0),  # identical distribution for every group
        (10, 1, 2, 100, 1, 2, 100),  # growing mean, spread and sample size
        (20, 0, 10, 1100, 0, 0, -50),  # shrinking sample size only
    )

    for (
        n_groups,
        loc,
        scale,
        size,
        loc_step,
        scale_step,
        size_step,
    ) in scenarios:
        samples = [
            np.round(
                stats.norm.rvs(
                    loc=loc + index * loc_step,
                    scale=scale + index * scale_step,
                    size=size + index * size_step,
                ),
                2,
            )
            for index in range(n_groups)
        ]
        expected_f, expected_p = scipy_anova(samples)
        test_and_check(samples, n_groups, expected_f, expected_p)


# Script entry point: run every scenario and print "Ok." only if none of
# them raised.
if __name__ == "__main__":
    test_anova()
    print("Ok.")
